#!/bin/bash
# Launcher script for Paddle job processes: the first command-line argument
# selects which start_* function below is run (see the case statement at the
# end of the file).
#
# `set -x` makes bash print every command, after expansion, to stderr before
# executing it.
# NOTE(review): the trace therefore contains the expanded values of the
# environment variables used below -- confirm none of them carry secrets.
set -x
# Run `paddle pserver` with --use_gpu=0, configured from the environment.
# Globals (read): PADDLE_INIT_PORT, PADDLE_INIT_PORTS_NUM,
#                 PADDLE_INIT_PORTS_NUM_FOR_SPARSE, PADDLE_INIT_NICS,
#                 PADDLE_INIT_NUM_GRADIENT_SERVERS
# Returns: the exit status of the pserver command.
start_pserver() {
    # Every expansion is quoted so that a value containing whitespace or glob
    # characters reaches the pserver as exactly one argument.
    # `stdbuf -oL` runs the command with line-buffered stdout.
    stdbuf -oL paddle pserver \
      --use_gpu=0 \
      --port="$PADDLE_INIT_PORT" \
      --ports_num="$PADDLE_INIT_PORTS_NUM" \
      --ports_num_for_sparse="$PADDLE_INIT_PORTS_NUM_FOR_SPARSE" \
      --nics="$PADDLE_INIT_NICS" \
      --comment=paddle_process_k8s \
      --num_gradient_servers="$PADDLE_INIT_NUM_GRADIENT_SERVERS"
}

# Start /usr/bin/pserver once the job's master pod is running.
#
# Waits (k8s_tools.py wait_pods_running) for one pod labelled
# paddle-job-master=<PADDLE_JOB_NAME>, looks up its IP with fetch_ips and
# hands it to the pserver as the etcd endpoint on port 2379.
# Globals (read):    PADDLE_JOB_NAME, PADDLE_INIT_PORT, PSERVERS
# Globals (written): MASTER_IP (exported)
# Returns: 1 when the master IP cannot be resolved, otherwise the exit status
#          of the pserver process.
start_new_pserver() {
  local master_label="paddle-job-master=${PADDLE_JOB_NAME}"

  stdbuf -oL python /root/k8s_tools.py wait_pods_running "${master_label}" 1

  # Assign first, export afterwards: `export VAR=$(cmd)` reports the status of
  # `export`, hiding a failed lookup. Without an IP the endpoint below would
  # degenerate to "http://:2379", so stop here instead.
  if ! MASTER_IP=$(python /root/k8s_tools.py fetch_ips "${master_label}") \
      || [[ -z "${MASTER_IP}" ]]; then
    echo "start_new_pserver: cannot resolve the IP of ${master_label}" >&2
    return 1
  fi
  export MASTER_IP

  stdbuf -oL /usr/bin/pserver \
    -port="${PADDLE_INIT_PORT}" \
    -num-pservers="${PSERVERS}" \
    -log-level=debug \
    -etcd-endpoint="http://${MASTER_IP}:2379"
}

# Run /usr/bin/master with a fixed set of flags: port 8080, one chunk per
# task, a 16s task timeout and the endpoint http://127.0.0.1:2379.
# Returns: the exit status of the master process.
start_master() {
  # The flag names are passed exactly as before, including the spelling of
  # "-task-timout-dur".
  # NOTE(review): confirm that spelling against the master binary's flags.
  local -a master_flags=(
    -port=8080
    -chunk-per-task=1
    -task-timout-dur=16s
    -endpoints=http://127.0.0.1:2379
  )

  # `stdbuf -oL` runs the command with line-buffered stdout.
  stdbuf -oL /usr/bin/master "${master_flags[@]}"
}

# Stop the script when too many pods of this job are in the Failed phase.
#
# Arguments: $1 - maximum tolerated number of failed pods
# Globals (read): PADDLE_JOB_NAME
# Behaviour:
#   - count > $1: logs the reason to stdout and /dev/termination-log, then
#     terminates the whole script with `exit 0`.
#     NOTE(review): status 0 is kept from the original code; presumably it
#     stops the pod from being treated as a failure -- confirm.
#   - otherwise, or when either number is missing/not an integer: returns 0
#     and the caller carries on.
check_failed_cnt() {
  local max_failed=$1
  local failed_count
  failed_count=$(python /root/k8s_tools.py count_pods_by_phase "paddle-job=${PADDLE_JOB_NAME}" Failed)

  # A failed lookup (empty or non-numeric output) or a missing threshold used
  # to surface as a `[` syntax error; skip the check explicitly instead. The
  # outcome is the same as before: execution continues.
  if ! [[ "${failed_count}" =~ ^[0-9]+$ && "${max_failed}" =~ ^[0-9]+$ ]]; then
    echo "check_failed_cnt: skipped, need integers (failed_count='${failed_count}', max_failed='${max_failed}')" >&2
    return 0
  fi

  # `[ -gt ]` (not `(( ))`) so values with leading zeros are read as decimal.
  if [ "${failed_count}" -gt "${max_failed}" ]; then
    stdbuf -oL echo "Failed trainer count beyond the threshold: ${max_failed}"
    echo "Failed trainer count beyond the threshold: " "${max_failed}" > /dev/termination-log
    exit 0
  fi
}

# Report the job's exit status and terminate the script with that status.
#
# Arguments: $1 - exit status of the job that just finished
# For the statuses listed in the table below a one-line description is written
# to /dev/termination-log; any other status writes nothing there.
# Never returns: always ends with `exit $1`.
check_trainer_ret() {
  local status=$1
  # "<status>:<message>" pairs, checked in this order.
  local -a known_failures=(
    "136:Error Arithmetic Operation(Floating Point Exception)"
    "139:Segmentation Fault"
    "1:General Error"
    "134:Program Abort"
  )
  local entry

  stdbuf -oL echo "job returned $status...setting pod return message..."
  stdbuf -oL echo "==============================="

  for entry in "${known_failures[@]}"; do
    # ${entry%%:*} is the status before the first ':', ${entry#*:} the text
    # after it.
    if [ $status -eq ${entry%%:*} ] ; then
      echo "${entry#*:}" > /dev/termination-log
      break
    fi
  done

  stdbuf -oL echo "termination log wroted..."
  exit $status
}

# Run ${ENTRY} for a fluid job after discovering the job's peers.
#
# Depending on PADDLE_TRAINING_ROLE:
#   TRAINER - waits for the pserver pods and for the trainer pods
#   PSERVER - waits for the pserver pods only
#   WORKER  - waits for the trainer pods only
# TRAINER and WORKER additionally run `check_failed_cnt 1` and take their id
# from the trainer label; every other role takes it from the pserver label.
#
# Globals (read):    PADDLE_JOB_NAME, PADDLE_TRAINING_ROLE, PADDLE_PSERVERS_NUM,
#                    PADDLE_TRAINERS_NUM, PADDLE_PORT, ENTRY
# Globals (written): PADDLE_PSERVERS, PADDLE_TRAINER_IPS, PADDLE_TRAINER_ID,
#                    PADDLE_PSERVER_ID (all exported for ${ENTRY})
# Ends by handing the status of ${ENTRY} to check_trainer_ret.
start_fluid_process() {
  local pserver_label="paddle-job-pserver=${PADDLE_JOB_NAME}"
  local trainer_label="paddle-job=${PADDLE_JOB_NAME}"
  local role="${PADDLE_TRAINING_ROLE}"
  local task_index=""
  # 1 for the roles that are looked up under the trainer label.
  local uses_trainer_label=0

  if [[ "${role}" == "TRAINER" || "${role}" == "PSERVER" ]]; then
    stdbuf -oL python /root/k8s_tools.py wait_pods_running "${pserver_label}" "${PADDLE_PSERVERS_NUM}"
  fi

  if [[ "${role}" == "TRAINER" || "${role}" == "WORKER" ]]; then
    uses_trainer_label=1
    stdbuf -oL python /root/k8s_tools.py wait_pods_running "${trainer_label}" "${PADDLE_TRAINERS_NUM}"
  fi

  PADDLE_PSERVERS=$(python /root/k8s_tools.py fetch_endpoints "${pserver_label}" "${PADDLE_PORT}")
  PADDLE_TRAINER_IPS=$(python /root/k8s_tools.py fetch_ips "${trainer_label}")
  export PADDLE_PSERVERS PADDLE_TRAINER_IPS

  if (( uses_trainer_label )); then
    check_failed_cnt 1
    task_index=$(python /root/k8s_tools.py fetch_id "${trainer_label}")
  else
    task_index=$(python /root/k8s_tools.py fetch_id "${pserver_label}")
  fi

  # The same index is published under both names.
  export PADDLE_TRAINER_ID="${task_index}"
  export PADDLE_PSERVER_ID="${task_index}"

  stdbuf -oL sh -c "${ENTRY}"
  check_trainer_ret $?
}

# start_tf_benchmark_process is only used for benchmarking.
#
# Exports the peer lists and this pod's index, then runs ${ENTRY}.
# TRAINING_ROLE == "TRAINER" gives TF_ROLE=worker (after `check_failed_cnt 1`);
# any other value gives TF_ROLE=ps.
# NOTE(review): this function reads TRAINING_ROLE, whereas start_fluid_process
# reads PADDLE_TRAINING_ROLE -- confirm that is intended.
#
# Globals (read):    PADDLE_JOB_NAME, PADDLE_INIT_PORT, TF_WORKER_PORT,
#                    TRAINING_ROLE, ENTRY
# Globals (written): PADDLE_INIT_PSERVERS, PADDLE_WORKERS, PADDLE_TRAINER_IPS,
#                    TF_WORKER_EPS, TF_ROLE, PADDLE_INIT_TRAINER_ID,
#                    PADDLE_TRAINER_ID (all exported for ${ENTRY})
# Ends by handing the status of ${ENTRY} to check_trainer_ret.
start_tf_benchmark_process() {
  # re-use the paddle job labels
  local pserver_label="paddle-job-pserver=${PADDLE_JOB_NAME}"
  local trainer_label="paddle-job=${PADDLE_JOB_NAME}"
  local task_index=""

  PADDLE_INIT_PSERVERS=$(python /root/k8s_tools.py fetch_ips "${pserver_label}" "${PADDLE_INIT_PORT}")
  # PADDLE_WORKERS and PADDLE_TRAINER_IPS were produced by two identical
  # queries; run the query once so both names always hold the same list.
  PADDLE_WORKERS=$(python /root/k8s_tools.py fetch_ips "${trainer_label}")
  PADDLE_TRAINER_IPS="${PADDLE_WORKERS}"
  TF_WORKER_EPS=$(python /root/k8s_tools.py fetch_ips "${trainer_label}" "${TF_WORKER_PORT}")
  export PADDLE_INIT_PSERVERS PADDLE_WORKERS PADDLE_TRAINER_IPS TF_WORKER_EPS

  if [[ "${TRAINING_ROLE}" == "TRAINER" ]]; then
    check_failed_cnt 1
    task_index=$(python /root/k8s_tools.py fetch_id "${trainer_label}")
    export TF_ROLE=worker
  else
    task_index=$(python /root/k8s_tools.py fetch_id "${pserver_label}")
    export TF_ROLE=ps
  fi

  # The same index is published under both names.
  export PADDLE_INIT_TRAINER_ID="${task_index}"
  export PADDLE_TRAINER_ID="${task_index}"

  stdbuf -oL sh -c "${ENTRY}"
  check_trainer_ret $?
}

# Run ${ENTRY} as a trainer inside ${TRAINER_PACKAGE}.
#
# Steps: `check_failed_cnt ${TRAINERS}`, wait for ${PSERVERS} pserver pods,
# sleep 5s, wait for one master pod, export the master's IP as MASTER_IP and
# ETCD_IP, prepend ${TRAINER_PACKAGE} to PYTHONPATH, cd into it, run ${ENTRY}.
#
# Arguments: $1 - optional; only echoed as "version" in the start-up line
# Globals (read):    TRAINERS, PSERVERS, PADDLE_JOB_NAME, TRAINER_PACKAGE,
#                    PADDLE_INIT_NUM_GRADIENT_SERVERS, ENTRY
# Globals (written): MASTER_IP, ETCD_IP, PYTHONPATH (all exported)
# Ends by handing the status of ${ENTRY} to check_trainer_ret.
start_new_trainer() {
  # FIXME(Yancey1989): use command-line interface to configure the max failed count
  check_failed_cnt "${TRAINERS}"

  local master_label="paddle-job-master=${PADDLE_JOB_NAME}"
  local pserver_label="paddle-job-pserver=${PADDLE_JOB_NAME}"

  stdbuf -oL python /root/k8s_tools.py wait_pods_running "${pserver_label}" "${PSERVERS}"
  sleep 5
  stdbuf -oL python /root/k8s_tools.py wait_pods_running "${master_label}" 1
  MASTER_IP=$(python /root/k8s_tools.py fetch_ips "${master_label}")
  export MASTER_IP
  export ETCD_IP="${MASTER_IP}"

  # NOTE: $TRAINER_PACKAGE may be large, do not copy
  export PYTHONPATH="${TRAINER_PACKAGE}:${PYTHONPATH}"
  # Quoted so a package path containing whitespace works. The ${VAR:+...}
  # form keeps the previous behaviour for an unset/empty TRAINER_PACKAGE
  # (a bare `cd`, i.e. go to $HOME).
  cd ${TRAINER_PACKAGE:+"${TRAINER_PACKAGE}"}

  stdbuf -oL echo "Starting training job:  ${TRAINER_PACKAGE}, num_gradient_servers: ${PADDLE_INIT_NUM_GRADIENT_SERVERS}, version:  $1"

  stdbuf -oL sh -c "${ENTRY}"
  check_trainer_ret $?
}

# Start a trainer using the v1 (`paddle train`) or v2 (${ENTRY}) API.
#
# Arguments: $1 - "v1" or "v2"; defaults to "v2" when empty or missing
# Globals (read):    PADDLE_JOB_NAME, PSERVERS, TRAINERS, TRAINER_PACKAGE,
#                    PADDLE_INIT_NUM_GRADIENT_SERVERS, ENTRY and, for v1 only,
#                    TRAIN_LIST, TOPOLOGY, OUTPUT and the PADDLE_INIT_* flags
#                    passed to `paddle train`
# Globals (written): PADDLE_INIT_PSERVERS, PADDLE_INIT_TRAINER_ID, PYTHONPATH,
#                    version (all exported)
# Files written:     /trainer_id, /trainer_count; for v1 also /train.list*
# Returns: 1 for an unsupported version. For v1/v2 the function ends in
#          check_trainer_ret, which receives the training command's status.
start_trainer() {
    # paddle v1 and v2 distributed training do not tolerate any failed trainer.
    check_failed_cnt 0

    local pserver_label="paddle-job-pserver=${PADDLE_JOB_NAME}"
    local trainer_label="paddle-job=${PADDLE_JOB_NAME}"

    stdbuf -oL python /root/k8s_tools.py wait_pods_running "${pserver_label}" "${PSERVERS}"
    stdbuf -oL python /root/k8s_tools.py wait_pods_running "${trainer_label}" "${TRAINERS}"

    PADDLE_INIT_PSERVERS=$(python /root/k8s_tools.py fetch_ips "${pserver_label}")
    PADDLE_INIT_TRAINER_ID=$(python /root/k8s_tools.py fetch_id "${trainer_label}")
    export PADDLE_INIT_PSERVERS PADDLE_INIT_TRAINER_ID
    stdbuf -oL echo "${PADDLE_INIT_TRAINER_ID}" > /trainer_id
    # FIXME: /trainer_count = PADDLE_INIT_NUM_GRADIENT_SERVERS
    stdbuf -oL echo "${PADDLE_INIT_NUM_GRADIENT_SERVERS}" > /trainer_count

    # NOTE: $TRAINER_PACKAGE may be large, do not copy
    export PYTHONPATH="${TRAINER_PACKAGE}:${PYTHONPATH}"
    # Quoted so a path containing whitespace works; ${VAR:+...} keeps the old
    # behaviour of a bare `cd` (go to $HOME) when TRAINER_PACKAGE is empty.
    cd ${TRAINER_PACKAGE:+"${TRAINER_PACKAGE}"}

    stdbuf -oL echo "Starting training job:  ${TRAINER_PACKAGE}, num_gradient_servers: ${PADDLE_INIT_NUM_GRADIENT_SERVERS}, trainer_id:  ${PADDLE_INIT_TRAINER_ID}"

    # `version` stays exported so ${ENTRY} can read it, as before.
    export version="v2"
    if [[ -z "$1" ]]; then
      # These two messages used to lack `echo`, so bash tried to execute the
      # text itself and reported "command not found".
      stdbuf -oL echo "didn't specify a version, use the default: ${version}"
    else
      export version="$1"
      stdbuf -oL echo "user specified version: ${version}"
    fi

    # FIXME: If we use the new PServer by Golang, add Kubernetes healthz
    # to wait for the PServer process to get ready. For now only sleep 20 seconds.
    sleep 20

    case "${version}" in
      "v1")
        local file_count lines_per_node current_list
        file_count=$(wc -l "${TRAIN_LIST}" | awk '{print $1}')
        if [ "${file_count}" -le "${PADDLE_INIT_NUM_GRADIENT_SERVERS}" ]; then
          echo "file count less than trainers"
          check_trainer_ret 0
        fi
        # NOTE(review): the divisor is (servers + 1), so the split below yields
        # more chunks than there are trainers -- confirm this is intended.
        lines_per_node=$(( file_count / (PADDLE_INIT_NUM_GRADIENT_SERVERS + 1) ))
        echo "spliting file to" "${lines_per_node}"
        cp "${TRAIN_LIST}" /
        cd /
        # Produces train.list000, train.list001, ... in /
        split -l "${lines_per_node}" -d -a 3 "${TRAIN_LIST}" train.list
        current_list=$(printf "train.list%03d" "${PADDLE_INIT_TRAINER_ID}")
        # always use /train.list for paddle v1 for each node.
        echo "File for current node ${current_list}"
        sleep 10
        cp "${current_list}" train.list

        cd ${TRAINER_PACKAGE:+"${TRAINER_PACKAGE}"}

        stdbuf -oL paddle train \
          --port="${PADDLE_INIT_PORT}" \
          --nics="${PADDLE_INIT_NICS}" \
          --ports_num="${PADDLE_INIT_PORTS_NUM}" \
          --ports_num_for_sparse="${PADDLE_INIT_PORTS_NUM_FOR_SPARSE}" \
          --num_passes="${PADDLE_INIT_NUM_PASSES}" \
          --trainer_count="${PADDLE_INIT_TRAINER_COUNT}" \
          --saving_period=1 \
          --log_period=20 \
          --local=0 \
          --rdma_tcp=tcp \
          --config="${TOPOLOGY}" \
          --use_gpu="${PADDLE_INIT_USE_GPU}" \
          --trainer_id="${PADDLE_INIT_TRAINER_ID}" \
          --save_dir="${OUTPUT}" \
          --pservers="${PADDLE_INIT_PSERVERS}" \
          --num_gradient_servers="${PADDLE_INIT_NUM_GRADIENT_SERVERS}"
        # paddle v1 API does not tolerate any failed trainer.
        check_trainer_ret $?
        ;;
      "v2")
        stdbuf -oL sh -c "${ENTRY}"
        # paddle v2 API does not tolerate any failed trainer.
        check_trainer_ret $?
        ;;
      *)
        # Previously an unknown version silently did nothing and reported
        # success; make the misconfiguration visible instead.
        echo "start_trainer: unsupported version '${version}' (expected v1 or v2)" >&2
        return 1
        ;;
    esac
}

# Print the list of sub-commands accepted by this script to stdout.
# The list mirrors the arms of the dispatching case statement, including
# start_master and start_fluid, which were previously missing here.
usage() {
    cat <<'EOF'
usage: paddle_k8s [<args>]:
  start_trainer  [v1|v2]    Start a trainer process with v1 or v2 API
  start_pserver             Start a pserver process
  start_new_pserver         Start a new pserver process
  start_new_trainer         Start a new trainer process
  start_master              Start a master process
  start_fluid               Start a fluid process (role from PADDLE_TRAINING_ROLE)
EOF
}

# Entry point: run the start_* function named by the first argument.
# Arguments: $1 - sub-command; $2 - forwarded to start_trainer only
# "--help" and any unrecognised (or missing) sub-command print the usage text.
# NOTE(review): no arm here invokes start_tf_benchmark_process.
main() {
    case "$1" in
        start_pserver)     start_pserver ;;
        start_trainer)     start_trainer $2 ;;
        start_new_trainer) start_new_trainer ;;
        start_new_pserver) start_new_pserver ;;
        start_master)      start_master ;;
        start_fluid)       start_fluid_process ;;
        --help | *)        usage ;;
    esac
}

main "$@"
